{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dtype = torch.FloatTensor\n",
    "sentence = (\n",
    "    'Lorem ipsum dolor sit amet consectetur adipisicing elit '\n",
    "    'sed do eiusmod tempor incididunt ut labore et dolore magna '\n",
    "    'aliqua Ut enim ad minim veniam quis nostrud exercitation'\n",
    ")\n",
    "word_num = {w:i for i,w in enumerate(list(set(sentence.split(' '))))}\n",
    "num_word = {i:w for i,w in enumerate(list(set(sentence.split(' '))))}\n",
    "n_class = len(word_num)\n",
    "max_len = len(sentence.split())\n",
    "n_hidden = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_batch(sentence):\n",
    "    input_batch = []\n",
    "    target_batch = []\n",
    "    words = sentence.split()\n",
    "    for i,word in enumerate(words[:-1]):\n",
    "        input = [word_num[w] for w in words[:(i+1)]]\n",
    "        input = input + [0] * (max_len-len(input))\n",
    "        target = word_num[words[i+1]]\n",
    "        input_batch.append(np.eye(n_class)[input])\n",
    "        target_batch.append(target)\n",
    "    return Variable(torch.Tensor(input_batch)),Variable(torch.Tensor(target_batch).type(torch.LongTensor))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class BiLSTM(nn.Module):\n",
    "    #RNN layer包含两个参数 第一个是input的embedding_size 第二个是hidden层的输出size\n",
    "    #input_size [seq_len,batch_size,embedding_size] -> [seq_len,batch_size,n_hidden]\n",
    "    def __init__(self):\n",
    "        super(BiLSTM,self).__init__()\n",
    "        self.lstm = nn.LSTM(input_size=n_class,hidden_size=n_hidden,bidirectional=True)\n",
    "        self.W = nn.Parameter(torch.randn(n_hidden*2,n_class).type(dtype))\n",
    "        self.b = nn.Parameter(torch.randn(n_class).type(dtype))\n",
    "        \n",
    "    def forward(self,X):\n",
    "        input = X.transpose(0,1)\n",
    "        hidden_state = Variable(torch.zeros(1*2,len(X),n_hidden))# [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n",
    "        cell_state = Variable(torch.zeros(1*2,len(X),n_hidden))# [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n",
    "        outputs ,(_,_) = self.lstm(input,(hidden_state,cell_state))\n",
    "        output = outputs[-1]#\n",
    "        model = torch.mm(output,self.W)+self.b\n",
    "        \n",
    "        return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:100 , loss:3.799682\n",
      "epoch:200 , loss:3.548778\n",
      "epoch:300 , loss:3.359399\n",
      "epoch:400 , loss:2.978504\n",
      "epoch:500 , loss:2.664912\n",
      "epoch:600 , loss:2.505166\n",
      "epoch:700 , loss:2.362607\n",
      "epoch:800 , loss:2.260327\n",
      "epoch:900 , loss:2.155129\n",
      "epoch:1000 , loss:2.060641\n",
      "epoch:1100 , loss:1.976329\n",
      "epoch:1200 , loss:1.897064\n",
      "epoch:1300 , loss:1.821422\n",
      "epoch:1400 , loss:1.746405\n",
      "epoch:1500 , loss:1.677236\n",
      "epoch:1600 , loss:1.606566\n",
      "epoch:1700 , loss:1.528436\n",
      "epoch:1800 , loss:1.449849\n",
      "epoch:1900 , loss:1.380469\n",
      "epoch:2000 , loss:1.317745\n",
      "epoch:2100 , loss:1.258311\n",
      "epoch:2200 , loss:1.205084\n",
      "epoch:2300 , loss:1.157622\n",
      "epoch:2400 , loss:1.114956\n",
      "epoch:2500 , loss:1.079390\n",
      "epoch:2600 , loss:1.044229\n",
      "epoch:2700 , loss:1.011358\n",
      "epoch:2800 , loss:0.980562\n",
      "epoch:2900 , loss:0.951746\n",
      "epoch:3000 , loss:0.924680\n",
      "epoch:3100 , loss:0.899149\n",
      "epoch:3200 , loss:0.874813\n",
      "epoch:3300 , loss:0.847159\n",
      "epoch:3400 , loss:0.815027\n",
      "epoch:3500 , loss:0.770562\n",
      "epoch:3600 , loss:1.934865\n",
      "epoch:3700 , loss:1.350491\n",
      "epoch:3800 , loss:1.262559\n",
      "epoch:3900 , loss:1.212480\n",
      "epoch:4000 , loss:1.171452\n",
      "epoch:4100 , loss:1.130753\n",
      "epoch:4200 , loss:1.089600\n",
      "epoch:4300 , loss:1.056194\n",
      "epoch:4400 , loss:1.025585\n",
      "epoch:4500 , loss:0.997198\n",
      "epoch:4600 , loss:0.970725\n",
      "epoch:4700 , loss:0.945746\n",
      "epoch:4800 , loss:0.921673\n",
      "epoch:4900 , loss:0.904026\n",
      "epoch:5000 , loss:0.880316\n",
      "epoch:5100 , loss:0.742061\n",
      "epoch:5200 , loss:0.703118\n",
      "epoch:5300 , loss:0.678341\n",
      "epoch:5400 , loss:0.656741\n",
      "epoch:5500 , loss:0.637852\n",
      "epoch:5600 , loss:0.620572\n",
      "epoch:5700 , loss:0.610418\n",
      "epoch:5800 , loss:0.598991\n",
      "epoch:5900 , loss:0.561000\n",
      "epoch:6000 , loss:0.543111\n",
      "epoch:6100 , loss:0.529312\n",
      "epoch:6200 , loss:0.516854\n",
      "epoch:6300 , loss:0.505152\n",
      "epoch:6400 , loss:0.493977\n",
      "epoch:6500 , loss:0.483200\n",
      "epoch:6600 , loss:0.472701\n",
      "epoch:6700 , loss:0.462176\n",
      "epoch:6800 , loss:0.450540\n",
      "epoch:6900 , loss:0.435000\n",
      "epoch:7000 , loss:0.415669\n",
      "epoch:7100 , loss:0.545475\n",
      "epoch:7200 , loss:0.411075\n",
      "epoch:7300 , loss:0.384173\n",
      "epoch:7400 , loss:0.369039\n",
      "epoch:7500 , loss:0.357369\n",
      "epoch:7600 , loss:0.347271\n",
      "epoch:7700 , loss:0.338386\n",
      "epoch:7800 , loss:0.330825\n",
      "epoch:7900 , loss:0.338809\n",
      "epoch:8000 , loss:0.326624\n",
      "epoch:8100 , loss:0.318184\n",
      "epoch:8200 , loss:0.310723\n",
      "epoch:8300 , loss:0.305594\n",
      "epoch:8400 , loss:0.300913\n",
      "epoch:8500 , loss:0.377754\n",
      "epoch:8600 , loss:0.307103\n",
      "epoch:8700 , loss:0.289883\n",
      "epoch:8800 , loss:0.285760\n",
      "epoch:8900 , loss:0.294378\n",
      "epoch:9000 , loss:0.295360\n",
      "epoch:9100 , loss:0.279004\n",
      "epoch:9200 , loss:0.274904\n",
      "epoch:9300 , loss:0.271671\n",
      "epoch:9400 , loss:0.268678\n",
      "epoch:9500 , loss:0.265786\n",
      "epoch:9600 , loss:0.305895\n",
      "epoch:9700 , loss:0.467708\n",
      "epoch:9800 , loss:0.270481\n",
      "epoch:9900 , loss:0.265485\n",
      "epoch:10000 , loss:0.260583\n"
     ]
    }
   ],
   "source": [
    "input_batch,target_batch = make_batch(sentence)\n",
    "model = BiLSTM()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(),lr = 0.001)\n",
    "\n",
    "for epoch in range(10000):\n",
    "    optimizer.zero_grad()\n",
    "    output = model(input_batch)\n",
    "    loss = criterion(output,target_batch)\n",
    "    if (epoch+1) % 100 == 0:\n",
    "        print('epoch:%d , loss:%f' % (epoch+1,loss))\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "res = model(input_batch).data.max(1,keepdim=True)[1]\n",
    "print(sentence)\n",
    "print([num_word[n.item()] for n in res.squeeze()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
